{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Watermark Analysis\n",
    "\n",
    "Notebook for performing analysis and visualization of the effects of watermarking schemes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Core imports used throughout the notebook (np, pd and tqdm are referenced below and must be in scope)\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from datasets import load_from_disk\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the processed dataset/frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which processed analysis dataset to load (\"in figure\" marks datasets noted as used in figures).\n",
    "# Other candidates:\n",
    "#   \"analysis_ds_1-21_greedy_redo\"\n",
    "#   \"analysis_ds_1-21_greedy_redo_truncated\"  (in figure)\n",
    "#   \"analysis_ds_1-20_more_attack\"            (in figure)\n",
    "save_name = \"analysis_ds_1-19_realnews_1-3_v1\"  # in figure\n",
    "\n",
    "save_dir = \"input/\" + save_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset previously saved (via the `datasets` library) under save_dir\n",
    "raw_data = load_from_disk(dataset_path=save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Convert to a pandas DataFrame, then clean and filter rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = raw_data.to_pandas()\n",
    "\n",
    "# Sanity check: count hard-blacklisted rows with a measurable WL fraction that is not 100%\n",
    "# (presumably retokenization artifacts, per the variable name).\n",
    "retok_problematic_rows = df[(df['w_bl_whitelist_fraction'] != -1.0) & (df['w_bl_whitelist_fraction'] != 1.0) & (df['bl_type'] == 'hard')]\n",
    "print(f\"Num rows that are hard-blacklisted, and measureable, but still have a non-100% WL fraction: {len(retok_problematic_rows)} out of {len(df[df['bl_type'] == 'hard'])}\")\n",
    "\n",
    "# Drop rows whose whitelist fraction is the -1.0 sentinel\n",
    "orig_len = len(df)\n",
    "\n",
    "df = df[df[\"no_bl_whitelist_fraction\"] != -1.0]\n",
    "df = df[df[\"w_bl_whitelist_fraction\"] != -1.0]\n",
    "\n",
    "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
    "\n",
    "# Drop rows missing either perplexity value\n",
    "orig_len = len(df)\n",
    "df = df[~(df[\"no_bl_ppl\"].isna() | df[\"w_bl_ppl\"].isna())]\n",
    "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
    "\n",
    "# Keep logit biases <= 100. NOTE(review): NaN biases are dropped here as well; if hard-blacklist\n",
    "# rows store NaN they never reach the np.inf assignment below -- confirm.\n",
    "orig_len = len(df)\n",
    "\n",
    "df = df[df[\"bl_logit_bias\"] <= 100.0]\n",
    "\n",
    "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
    "\n",
    "# Sampled rows must use a single beam; all non-sampled rows are kept\n",
    "orig_len = len(df)\n",
    "\n",
    "df = df[((df[\"use_sampling\"]==True) & (df[\"num_beams\"] == 1)) | (df[\"use_sampling\"]==False)]\n",
    "\n",
    "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
    "\n",
    "# Fill missing sampling temperatures: 0.0 when not sampling, 1.0 when sampling\n",
    "df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"].fillna(0.0)\n",
    "df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"].fillna(1.0)\n",
    "\n",
    "# Hard blacklisting is represented as an infinite logit bias\n",
    "# (crosscheck with whats hardcoded in the bl processor; 10000 was tried here before)\n",
    "df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = np.inf\n",
    "\n",
    "# delta is the logit bias; gamma = 1 - bl_proportion, rounded to 3 decimals\n",
    "df[\"delta\"] = df[\"bl_logit_bias\"].values\n",
    "df[\"gamma\"] = 1 - df[\"bl_proportion\"].values\n",
    "df[\"gamma\"] = df[\"gamma\"].round(3)\n",
    "\n",
    "df[\"no_bl_act_num_wl_tokens\"] = np.round(df[\"no_bl_whitelist_fraction\"].values*df[\"no_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
    "df[\"w_bl_act_num_wl_tokens\"] = np.round(df[\"w_bl_whitelist_fraction\"].values*df[\"w_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
    "\n",
    "df[\"w_bl_std_num_wl_tokens\"] = np.sqrt(df[\"w_bl_var_num_wl_tokens\"].values)\n",
    "\n",
    "# Guard on the column actually existing (a bare non-empty string literal is always truthy,\n",
    "# which made this guard a no-op and raised KeyError on datasets without the column).\n",
    "if \"real_completion_length\" in df.columns:\n",
    "    df[\"baseline_num_tokens_generated\"] = df[\"real_completion_length\"].values\n",
    "\n",
    "if \"actual_attacked_ratio\" in df.columns:\n",
    "    df[\"actual_attacked_fraction\"] = df[\"actual_attacked_ratio\"].values*df[\"replace_ratio\"].values\n",
    "\n",
    "df[\"baseline_hit_list_length\"] = df[\"baseline_hit_list\"].apply(len)\n",
    "df[\"no_bl_hit_list_length\"] = df[\"no_bl_hit_list\"].apply(len)\n",
    "df[\"w_bl_hit_list_length\"] = df[\"w_bl_hit_list\"].apply(len)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filter for the generation lengths we want to look at"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only rows where all three generations (baseline, no_bl, w_bl) have a hit-list length\n",
    "# within [lower_T, upper_T], inclusive; now also applies to the truncated version.\n",
    "lower_T = 195\n",
    "upper_T = 205\n",
    "\n",
    "orig_len = len(df)\n",
    "\n",
    "length_cols = [\"baseline_hit_list_length\", \"no_bl_hit_list_length\", \"w_bl_hit_list_length\"]\n",
    "within_bounds = (df[length_cols] >= lower_T) & (df[length_cols] <= upper_T)\n",
    "df = df[within_bounds.all(axis=1)]\n",
    "\n",
    "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Add z-scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "import scipy.stats\n",
    "\n",
    "def compute_z_score(observed_wl_frac, T, gamma):\n",
    "    \"\"\"z-statistic of an observed whitelist fraction over T tokens against rate gamma.\n",
    "\n",
    "    The standard error used is sqrt(gamma*(1-gamma)/T), i.e. that of a Bernoulli(gamma)\n",
    "    mean over T tokens.\n",
    "    \"\"\"\n",
    "    deviation = observed_wl_frac - gamma\n",
    "    std_err = sqrt(gamma*(1-gamma)/T)\n",
    "    return deviation/std_err\n",
    "\n",
    "def compute_wl_for_z(z, T, gamma):\n",
    "    \"\"\"Inverse of compute_z_score: whitelist-token count (out of T) whose fraction yields z.\"\"\"\n",
    "    std_err = sqrt(gamma*(1-gamma)/T)\n",
    "    return ((z*std_err)+gamma)*T\n",
    "\n",
    "def compute_p_value(z):\n",
    "    \"\"\"Upper-tail standard-normal p-value of |z|.\n",
    "\n",
    "    NOTE(review): abs() makes strongly negative z look significant too, unlike the\n",
    "    one-sided `z > cutoff` rule in reject_null_hypo -- confirm this is intended.\n",
    "    \"\"\"\n",
    "    return scipy.stats.norm.sf(abs(z))\n",
    "\n",
    "# One z-score column per generation type: baseline, no_bl and w_bl\n",
    "for prefix in [\"baseline\", \"no_bl\", \"w_bl\"]:\n",
    "    z_inputs = df[[f\"{prefix}_whitelist_fraction\", f\"{prefix}_num_tokens_generated\", \"gamma\"]]\n",
    "    df[f\"{prefix}_z_score\"] = z_inputs.apply(lambda tup: compute_z_score(*tup), axis=1)\n",
    "\n",
    "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
    "    z_inputs = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]]\n",
    "    df[\"w_bl_attacked_z_score\"] = z_inputs.apply(lambda tup: compute_z_score(*tup), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "\n",
    "# Attack datasets only: whitelist-token counts for the attacked generations.\n",
    "# (w_bl_attacked_z_score is already added, under the same condition, in the z-score cell.)\n",
    "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
    "    df[\"w_bl_attacked_act_num_wl_tokens\"] = np.round(df[\"w_bl_attacked_whitelist_fraction\"].values*df[\"w_bl_attacked_num_tokens_generated\"],1) # round to 1 for sanity\n",
    "\n",
    "    # An expression nested inside an `if` is not auto-displayed by Jupyter, so display explicitly.\n",
    "    display(df[[\"bl_proportion\",\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\",\"w_bl_attacked_act_num_wl_tokens\", \"w_bl_attacked_z_score\"]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Groupby"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fields identifying one experimental configuration. 'delta' precedes 'gamma' in both\n",
    "# branches so hard-coded group tuples like (use_sampling, num_beams, delta, gamma) mean the same thing.\n",
    "if \"w_bl_attacked_whitelist_fraction\" in df.columns: \n",
    "    groupby_fields = ['use_sampling','num_beams','delta','gamma','replace_ratio'] # attack grouping\n",
    "else:\n",
    "    groupby_fields = ['use_sampling','num_beams','delta','gamma'] # regular grouping\n",
    "    # groupby_fields = ['use_sampling','delta','gamma'] # regular grouping, but no beam variation\n",
    "    # groupby_fields = ['delta','gamma'] # regular grouping, but no beam variation, and all sampling"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Main groupby"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One group per experimental configuration (key order follows groupby_fields)\n",
    "grouped_df = df.groupby(by=groupby_fields)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    f\"Number of rows after filtering: {len(df)}\",\n",
    "    f\"Number of groups: {len(grouped_df)}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute confusion-matrix rates (FPR/TNR/TPR/FNR) at fixed z-score thresholds for tabulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn.metrics as metrics  # only needed for an ROC/AUC (roc_curve/auc) variant; unused in this cell\n",
    "\n",
    "def reject_null_hypo(z_score=None,cuttoff=None):\n",
    "    \"\"\"One-sided test: True wherever z_score exceeds cuttoff (elementwise for arrays).\"\"\"\n",
    "    return z_score > cuttoff\n",
    "\n",
    "# (z-score column prefix, name of the rejection rate, name of the non-rejection rate).\n",
    "# Rejecting the null (z > thresh) counts as a false positive for baseline / no_bl text\n",
    "# and as a true positive for w_bl / w_bl_attacked text.\n",
    "rate_specs = [\n",
    "    (\"baseline\", \"fpr\", \"tnr\"),\n",
    "    (\"no_bl\", \"fpr\", \"tnr\"),\n",
    "    (\"w_bl\", \"tpr\", \"fnr\"),\n",
    "    (\"w_bl_attacked\", \"tpr\", \"fnr\"),\n",
    "]\n",
    "\n",
    "records = []\n",
    "\n",
    "for group_params in tqdm(list(grouped_df.groups.keys())):\n",
    "    sub_df = grouped_df.get_group(group_params)\n",
    "    grp_size = len(sub_df)\n",
    "\n",
    "    record = {k:v for k,v in zip(groupby_fields,group_params)}\n",
    "    record[\"count\"] = grp_size  # independent of the threshold, so set once\n",
    "\n",
    "    for thresh in [4.0,5.0]:\n",
    "        for prefix, reject_rate, keep_rate in rate_specs:\n",
    "            score_col = f\"{prefix}_z_score\"\n",
    "            if prefix == \"w_bl_attacked\" and score_col not in sub_df.columns:\n",
    "                continue  # attacked scores only exist for attack datasets\n",
    "            rejected = reject_null_hypo(z_score=sub_df[score_col].values,cuttoff=thresh)\n",
    "            record[f\"{prefix}_{reject_rate}_at_{thresh}\"] = rejected.sum() / grp_size\n",
    "            record[f\"{prefix}_{keep_rate}_at_{thresh}\"] = (~rejected).sum() / grp_size\n",
    "\n",
    "    records.append(record)\n",
    "\n",
    "roc_df = pd.DataFrame.from_records(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "std_threshes = [4.0, 5.0]  # e.g. append 6.0 for a stricter threshold\n",
    "\n",
    "# Leading columns of the summary table (alternative: [\"use_sampling\", \"replace_ratio\", \"count\"])\n",
    "columns = [\"delta\", \"gamma\", \"count\"]\n",
    "\n",
    "for thresh in std_threshes:\n",
    "    if f\"w_bl_attacked_fnr_at_{thresh}\" not in roc_df.columns:\n",
    "        # regular runs: baseline FPR/TNR and watermarked TPR/FNR\n",
    "        columns.extend([f\"baseline_fpr_at_{thresh}\", f\"baseline_tnr_at_{thresh}\", f\"w_bl_tpr_at_{thresh}\", f\"w_bl_fnr_at_{thresh}\"])\n",
    "    else:\n",
    "        # attack runs: watermarked TPR/FNR before and after the attack\n",
    "        columns.extend([f\"w_bl_tpr_at_{thresh}\", f\"w_bl_fnr_at_{thresh}\", f\"w_bl_attacked_tpr_at_{thresh}\", f\"w_bl_attacked_fnr_at_{thresh}\"])\n",
    "\n",
    "# filter or not: sampled runs at a few delta/gamma values\n",
    "# (alternatives: roc_df[\"replace_ratio\"].isin([0.1, 0.3, 0.5, 0.7]), or no filter at all)\n",
    "sub_df = roc_df[\n",
    "    (roc_df[\"use_sampling\"] == True)\n",
    "    & roc_df[\"delta\"].isin([1.0, 2.0, 10.0])\n",
    "    & roc_df[\"gamma\"].isin([0.1, 0.25, 0.5])\n",
    "]\n",
    "\n",
    "sub_df.sort_values(\"delta\")[columns]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternatives: tabulate all of roc_df (minus \"count\") sorted by gamma, delta or num_beams,\n",
    "# or sort sub_df by \"num_beams\" instead.\n",
    "latex_table = sub_df.sort_values(\"delta\")[columns].round(3).to_latex(index=False)\n",
    "print(latex_table)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (Optional) write a trimmed copy of the frame to CSV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cols_to_drop = ['no_bl_gen_time',\n",
    "#     'w_bl_gen_time', 'spike_entropies', \n",
    "#     'no_bl_sec_per_tok', 'no_bl_tok_per_sec', 'w_bl_sec_per_tok',\n",
    "#     'w_bl_tok_per_sec', 'baseline_loss','no_bl_loss',\n",
    "#     'w_bl_loss',  'model_name', 'dataset_name',\n",
    "#     'dataset_config_name', 'shuffle_dataset', 'shuffle_seed',\n",
    "#     'shuffle_buffer_size', 'max_new_tokens', 'min_prompt_tokens',\n",
    "#     'limit_indices', 'input_truncation_strategy',\n",
    "#     'input_filtering_strategy', 'output_filtering_strategy', 'initial_seed',\n",
    "#     'dynamic_seed','no_repeat_ngram_size', 'early_stopping',\n",
    "#     'oracle_model_name', 'no_wandb', 'wandb_project', 'wandb_entity', 'output_dir', 'load_prev_generations', 'store_bl_ids',\n",
    "#     'store_spike_ents',  'generate_only',\n",
    "#     'SLURM_JOB_ID', 'SLURM_ARRAY_JOB_ID', 'SLURM_ARRAY_TASK_ID',\n",
    "#     'gen_table_already_existed', 'baseline_num_toks_gend_eq_0',\n",
    "#     'baseline_hit_list', 'no_bl_num_toks_gend_eq_0',\n",
    "#     'no_bl_hit_list', 'w_bl_num_toks_gend_eq_0', 'w_bl_hit_list']\n",
    "# df.drop(cols_to_drop,axis=1).to_csv(\"input/for_poking.csv\")\n",
    "# df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All columns available after cleaning and feature construction\n",
    "df.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extract examples (actual text) for tabulation based on entropy and z scores (tables 1,3,4,5,6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key order for the group tuples used below\n",
    "print(\"groupby legend: {}\".format(groupby_fields))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configurations to draw examples from; tuple order follows groupby_fields (printed above).\n",
    "groups = [\n",
    "    (True, 1, 2.0, 0.5),\n",
    "    # (True, 1, 10.0, 0.5),\n",
    "    # (False, 8, 2.0, 0.5),\n",
    "    # (False, 8, 10.0, 0.5),\n",
    "]\n",
    "# Built with a comprehension so the earlier summary-table `sub_df` is no longer overwritten.\n",
    "subset_df = pd.concat([grouped_df.get_group(group) for group in groups], axis=0)\n",
    "\n",
    "print(len(subset_df))\n",
    "\n",
    "cols_to_tabulate = [\n",
    "    'idx', \n",
    "    'truncated_input', \n",
    "    # 'prompt_length',\n",
    "    'baseline_completion',\n",
    "    'no_bl_output', \n",
    "    'w_bl_output', \n",
    "    # 'real_completion_length',\n",
    "    # 'no_bl_num_tokens_generated',\n",
    "    # 'w_bl_num_tokens_generated',\n",
    "    'avg_spike_entropy',\n",
    "    # 'baseline_whitelist_fraction',\n",
    "    # 'no_bl_whitelist_fraction',\n",
    "    # 'w_bl_whitelist_fraction',\n",
    "    # 'baseline_z_score',\n",
    "    'no_bl_z_score',\n",
    "    'w_bl_z_score',\n",
    "    # 'baseline_ppl',\n",
    "    'no_bl_ppl',\n",
    "    'w_bl_ppl'\n",
    "]\n",
    "\n",
    "# How often each prompt idx appears across the selected groups\n",
    "# (column name spelling kept as-is for existing references).\n",
    "subset_df[\"occurences\"] = subset_df[\"idx\"].map(subset_df[\"idx\"].value_counts()).astype(\"int64\")\n",
    "\n",
    "# cols_to_tabulate = [\"occurences\"] + cols_to_tabulate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# subset_df[cols_to_tabulate].sort_values([\"occurences\", \"idx\"],ascending=False)\n",
    "# subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_prompt_chars = 200\n",
    "max_output_chars = 200\n",
    "\n",
    "# Keep only the tail of each prompt and the head of each output for the tables.\n",
    "# Markers are \"(...)\" / \"[...continues]\" (an earlier variant used \"[...]\" / \"[...truncated]\");\n",
    "# per the original note, a leading bracket is a problem when the index is not included.\n",
    "# NOTE: this overwrites subset_df columns, so re-running can add the markers twice to short texts.\n",
    "subset_df[\"truncated_input\"] = subset_df[\"truncated_input\"].apply(lambda s: f\"(...){s[-max_prompt_chars:]}\")\n",
    "for output_col in [\"baseline_completion\", \"no_bl_output\", \"w_bl_output\"]:\n",
    "    subset_df[output_col] = subset_df[output_col].apply(lambda s: f\"{s[:max_output_chars]}[...continues]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of example rows taken from each end (and the middle) of a sorted table\n",
    "slice_size = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _middle_window(sorted_df, divisor, size):\n",
    "    \"\"\"Return `size` consecutive rows of sorted_df centred near row len(sorted_df)//divisor.\n",
    "\n",
    "    The start is clamped at 0: a negative iloc start would wrap around to the end of the\n",
    "    frame and usually yield an empty slice.\n",
    "    \"\"\"\n",
    "    center = len(sorted_df) // divisor\n",
    "    start = max(center - size // 2, 0)\n",
    "    return sorted_df.iloc[start:start + size]\n",
    "\n",
    "# The mid windows sit ~1/5 of the way up the ascending entropy sort and\n",
    "# ~1/65 of the way up the ascending z-score sort.\n",
    "entropy_mid_divisor = 5\n",
    "z_mid_divisor = 65\n",
    "\n",
    "entropy_sorted = subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True)\n",
    "high_entropy_examples = entropy_sorted.tail(slice_size)\n",
    "mid_entropy_examples = _middle_window(entropy_sorted, entropy_mid_divisor, slice_size)\n",
    "low_entropy_examples = entropy_sorted.head(slice_size)\n",
    "\n",
    "z_sorted = subset_df[cols_to_tabulate].sort_values([\"w_bl_z_score\"],ascending=True)\n",
    "high_z_examples = z_sorted.tail(slice_size)\n",
    "mid_z_examples = _middle_window(z_sorted, z_mid_divisor, slice_size)\n",
    "low_z_examples = z_sorted.head(slice_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (use high_entropy_examples here to inspect the entropy slice instead)\n",
    "high_z_examples.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (use mid_entropy_examples here to inspect the entropy slice instead)\n",
    "mid_z_examples.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (use low_entropy_examples here to inspect the entropy slice instead)\n",
    "low_z_examples.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Highest- and lowest-z examples, highest first\n",
    "# (alternative: combine high_entropy_examples and low_entropy_examples instead)\n",
    "slices_set_df = pd.concat([high_z_examples, low_z_examples], axis=0)\n",
    "slices_set_df = slices_set_df.sort_values(\"w_bl_z_score\", ascending=False)\n",
    "slices_set_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# slices_set_df.T.iloc[:,0:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(slices_set_df.to_latex(index=False))\n",
    "# print(low_entropy_examples.to_latex(index=False))\n",
    "# print(mid_entropy_examples.to_latex(index=False))\n",
    "# print(high_entropy_examples.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for c,t in zip(low_entropy_examples.columns,low_entropy_examples.dtypes):\n",
    "#     if t==object:\n",
    "#         low_entropy_examples[c] = low_entropy_examples[c].apply(lambda s: f\"{s[:100]}[...truncated]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# low_entropy_examples.T.to_latex(buf=open(\"figs/low_ent_examples.txt\", \"w\"),index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frame to export; any of high/mid/low_entropy_examples or high/mid/low_z_examples\n",
    "# can be substituted for slices_set_df.\n",
    "cols_to_drop = [\"idx\", \"avg_spike_entropy\", \"no_bl_z_score\"]  # optionally also \"no_bl_ppl\", \"w_bl_ppl\"\n",
    "df_to_write = slices_set_df.drop(cols_to_drop, axis=1)\n",
    "\n",
    "# Text (object) columns get a 3cm paragraph column, all others 0.4cm, separated by '|'.\n",
    "column_format = \"|\".join(r\"p{3cm}\" if dtype == object else r\"p{0.4cm}\" for dtype in df_to_write.dtypes)\n",
    "\n",
    "with pd.option_context(\"max_colwidth\", 1000):\n",
    "    latex_str = df_to_write.round(2).to_latex(column_format=column_format, index=False)\n",
    "\n",
    "print(latex_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# column_format=\"\".join([r'p{2cm}|' for c in low_entropy_examples.columns])\n",
    "# column_format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# low_entropy_examples.dtypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First 10 watermarked outputs of one configuration, printed without column truncation\n",
    "# (tuple order follows groupby_fields).\n",
    "example_group = (True, 1, 2.0, 0.9)\n",
    "with pd.option_context(\"max_colwidth\", 1000):\n",
    "    print(grouped_df.get_group(example_group)[\"w_bl_output\"].head(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up data for charts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# viz_df = pd.DataFrame()\n",
    "\n",
    "# # set the hparam keys, including an indiv column for each you want to ablate on\n",
    "# viz_df[\"bl_hparams\"] = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe().index.to_list()\n",
    "# for i,key in enumerate(groupby_fields):\n",
    "#     viz_df[key] = viz_df[\"bl_hparams\"].apply(lambda tup: tup[i])\n",
    "\n",
    "# describe_dict = grouped_df[\"w_bl_whitelist_fraction\"].describe()\n",
    "# viz_df[\"w_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "# viz_df[\"w_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "# describe_dict = grouped_df[\"no_bl_whitelist_fraction\"].describe()\n",
    "# viz_df[\"no_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "# viz_df[\"no_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "\n",
    "# describe_dict = grouped_df[\"w_bl_z_score\"].describe()\n",
    "# viz_df[\"w_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "# viz_df[\"w_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "# describe_dict = grouped_df[\"no_bl_z_score\"].describe()\n",
    "# viz_df[\"no_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "# viz_df[\"no_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "\n",
    "# describe_dict = grouped_df[\"w_bl_ppl\"].describe()\n",
    "# viz_df[\"w_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "# viz_df[\"w_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "# describe_dict = grouped_df[\"no_bl_ppl\"].describe()\n",
    "# viz_df[\"no_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "# viz_df[\"no_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "# describe_dict = grouped_df[\"avg_spike_entropy\"].describe()\n",
    "# viz_df[\"avg_spike_entropy_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "# viz_df[\"avg_spike_entropy_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "# print(f\"groupby legend: {groupby_fields}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # filtering\n",
    "\n",
    "# viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == True))] # sampling\n",
    "\n",
    "# # viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False))] # greedy\n",
    "\n",
    "\n",
    "# # fix one of the bl params for analytic chart\n",
    "# viz_df = viz_df[(viz_df[\"gamma\"]==0.5) & (viz_df[\"delta\"]<=10.0)]\n",
    "\n",
    "# # viz_df = viz_df[(viz_df[\"delta\"] > 0.5) & (viz_df[\"delta\"]<=10.0)]\n",
    "\n",
    "# # viz_df = viz_df[(viz_df[\"delta\"]==0.5) | (viz_df[\"delta\"]==2.0) | (viz_df[\"delta\"]==10.0)]\n",
    "\n",
    "# # viz_df = viz_df[(viz_df[\"delta\"]!=0.1)&(viz_df[\"delta\"]!=0.5)&(viz_df[\"delta\"]!=50.0)]\n",
    "\n",
    "# # viz_df = viz_df[(viz_df[\"delta\"]!=50.0)]\n",
    "# # viz_df = viz_df[(viz_df[\"delta\"]!=50.0) & (viz_df[\"num_beams\"]!=1)]\n",
    "\n",
    "# print(len(viz_df))\n",
    "\n",
    "# viz_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize the WL/BL hits via highlighting in html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# idx = 75\n",
    "# # idx = 62\n",
    "\n",
    "# # debug\n",
    "# # idx = 7\n",
    "# # idx = 18\n",
    "# # idx = 231\n",
    "\n",
    "# print(gen_table_w_bl_stats[idx])\n",
    "# print(f\"\\nPrompt:\",gen_table_w_bl_stats[idx][\"truncated_input\"])\n",
    "# print(f\"\\nBaseline (real text):{gen_table_w_bl_stats[idx]['baseline_completion']}\")\n",
    "# print(f\"\\nNo Blacklist:{gen_table_w_bl_stats[idx]['no_bl_output']}\")\n",
    "# print(f\"\\nw/ Blacklist:{gen_table_w_bl_stats[idx]['w_bl_output']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from ipymarkup import show_span_box_markup, get_span_box_markup\n",
    "# from ipymarkup.palette import palette, RED, GREEN, BLUE\n",
    "\n",
    "# from IPython.display import display, HTML\n",
    "\n",
    "# from transformers import GPT2TokenizerFast\n",
    "# # fast_tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n",
    "# fast_tokenizer = GPT2TokenizerFast.from_pretrained(\"facebook/opt-2.7b\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %autoreload\n",
    "\n",
    "# vis_bl = partial(\n",
    "#     compute_bl_metrics,\n",
    "#     tokenizer=fast_tokenizer,\n",
    "#     hf_model_name=gen_table_meta[\"model_name\"],\n",
    "#     initial_seed=gen_table_meta[\"initial_seed\"],\n",
    "#     dynamic_seed=gen_table_meta[\"dynamic_seed\"],\n",
    "#     bl_proportion=gen_table_meta[\"bl_proportion\"],\n",
    "#     record_hits = True,\n",
    "#     use_cuda=True, # this is obvi critical to match the pseudorandomness\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# stats = vis_bl(gen_table_w_bl_stats[idx], 0)\n",
    "\n",
    "# baseline_hit_list = stats[\"baseline_hit_list\"]\n",
    "# no_bl_hit_list = stats[\"no_bl_hit_list\"]\n",
    "# w_bl_hit_list = stats[\"w_bl_hit_list\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# text = stats[\"truncated_input\"]\n",
    "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
    "# hit_list = baseline_hit_list\n",
    "\n",
    "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
    "# charspans = [cs for cs in charspans if cs is not None]\n",
    "# # spans = [(cs.start,cs.end, \"PR\") for i,cs in enumerate(charspans)]\n",
    "# spans = []\n",
    "\n",
    "# html = get_span_box_markup(text, spans, palette=palette(PR=BLUE), background='white', text_color=\"black\")\n",
    "\n",
    "\n",
    "# with open(\"figs/prompt_html.html\", \"w\") as f:\n",
    "#     f.write(HTML(html).data)\n",
    "\n",
    "# HTML(html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# text = stats[\"baseline_completion\"]\n",
    "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
    "# hit_list = baseline_hit_list\n",
    "\n",
    "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
    "# charspans = [cs for cs in charspans if cs is not None]\n",
    "# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n",
    "\n",
    "# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n",
    "\n",
    "\n",
    "# with open(\"figs/baseline_html.html\", \"w\") as f:\n",
    "#     f.write(HTML(html).data)\n",
    "\n",
    "# HTML(html)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# text = stats[\"no_bl_output\"]\n",
    "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
    "# hit_list = no_bl_hit_list\n",
    "\n",
    "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
    "# charspans = [cs for cs in charspans if cs is not None]\n",
    "# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n",
    "\n",
    "# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n",
    "\n",
    "\n",
    "# with open(\"figs/no_bl_html.html\", \"w\") as f:\n",
    "#     f.write(HTML(html).data)\n",
    "\n",
    "# HTML(html)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# text = stats[\"w_bl_output\"]\n",
    "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
    "# hit_list = w_bl_hit_list\n",
    "\n",
    "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
    "# charspans = [cs for cs in charspans if cs is not None]\n",
    "# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n",
    "\n",
    "# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n",
    "\n",
    "\n",
    "# with open(\"figs/w_bl_html.html\", \"w\") as f:\n",
    "#     f.write(HTML(html).data)\n",
    "\n",
    "# HTML(html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "365524a309ad80022da286f2ec5d2060ce5cb229abb6076cf68d9a1ab14bd8fe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
